from matplotlib import rcParams
%matplotlib inline
import pandas as pd
import scanpy as sc
from matplotlib import rcParams
import numpy as np
import scfate as scr
from scfate.genome.sequences import *
from scfate.motifs.score import *
from scfate.genome.FastaAnalyzer import FastaAnalyzer
from lib.HumanTFs import HumanTFs
from sklearn.metrics import roc_auc_score
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
def set_figsize(w, h):
    """Set matplotlib's default figure size (``figure.figsize``) to ``(w, h)``."""
    rcParams['figure.figsize'] = (w, h)
def set_figdpi(dpi):
    """Set matplotlib's default figure resolution (``figure.dpi``) to ``dpi``."""
    resolution = dpi
    rcParams['figure.dpi'] = resolution
# Load the integrated retinal scRNA-seq object (scGen-integrated, unscaled,
# HVG subset — as suggested by the file name; confirm against the pipeline).
adata = sc.read('/mnt/znas/icb_zstore01/groups/ml01/workspace/ignacio.ibarra/theislab/retinal_scRNAseq_integration/data/integrated/integration_scgen_unscaled_hvg.h5ad')
# Notebook display of the per-cell annotation table.
adata.obs
set_figsize(10, 5)
# UMAP coloured by the 'cell.type' and 'leiden' columns of adata.obs, side by side.
sc.pl.umap(adata, color=['cell.type', 'leiden'], ncols=2)
# Notebook display of the AnnData summary.
adata
# `species` only selects which precomputed archetype-hit table is read below.
species = 'human'
archetypes_path = 'data/%s_tss_2000_archetypes_hits_top5_mean_scaled.tsv.gz' % species
import os
# Size of the archetype table in units of 1e6 bytes (notebook sanity check; the
# value is only displayed, never stored).
os.path.getsize(archetypes_path) / 1e6
# NOTE: IPython shell escape — this line is only valid inside a notebook /
# IPython session, not in a plain Python interpreter.
!ls -ltrh $archetypes_path
# Precomputed archetype-hit table; its first column is used as the index and is
# matched against adata.var.index below.
archetypes = pd.read_csv(archetypes_path, sep='\t', compression='gzip', index_col=0)
# Keep only rows whose identifier is present in the expression object.
archetypes = archetypes[archetypes.index.isin(set(adata.var.index))]
archetypes.shape
# Restrict the expression object to variables that have archetype hits.
# `.copy()` turns the resulting view into a real AnnData so that the writes to
# .varm / .obsm / .obs below do not act on (and implicitly re-materialise) a view.
adata = adata[:, adata.var.index.isin(set(archetypes.index))].copy()
adata.shape
# Align the archetype rows to the variable order of adata before storing them.
adata.varm['archetypes'] = np.array(archetypes.reindex(adata.var.index))
adata.varm['archetypes']
# Quantile used for scoring; every key below is derived from `q` so that
# changing it here is enough to re-run the cell consistently.
q = 95
qp = q / 100
k = 'archetypes.scores.q%i' % q
print(k)
# `score_archetypes` comes from the scfate star-imports; the loop below reads
# its result from adata.obsm[k].
score_archetypes(adata, q=qp, key_obsm=k)
adata.shape
# Expose every score column of adata.obsm[k] as its own .obs column. Column
# labels are converted to int and shifted by one for the .obs column name.
score_prefix = 'score.archetypes.q%i' % q
for c in adata.obsm[k]:
    adata.obs['%s.%i' % (score_prefix, int(c) + 1)] = adata.obsm[k][c]
key_list = [c for c in adata.obs if score_prefix in c]
hm = sc.get.obs_df(adata, key_list)
# scale: z-score every score column across cells
hm = (hm - hm.mean(axis=0)) / hm.std(axis=0)
hm
# One-vs-rest ROC-AUC of every z-scored archetype score (columns of `hm`) for
# every group of the 'cell.type' and 'leiden' annotations.
df = []
score_cols = [ci for ci in hm if 'score' in ci]
for by in ['cell.type', 'leiden']:
    # Group labels aligned to the rows of `hm`; `hm` itself is left untouched.
    labels = adata.obs[by].reindex(hm.index)
    # Cells without a label cannot form a foreground group; they still count
    # as background for every other group.
    for s in labels.dropna().unique():
        print(by, s)
        # The foreground mask only depends on the group, not on the score column.
        is_fg = (labels == s).to_numpy()
        y_true = is_fg.astype(int)
        for ci in score_cols:
            y_score = hm[ci].to_numpy(dtype=float)
            missing = np.isnan(y_score)
            # A column that is NaN everywhere (e.g. zero variance before
            # scaling) carries no ranking information, so it is skipped.
            # The check must use np.isnan and run before the fill below:
            # `x == np.nan` is always False.
            if missing.all():
                continue
            y_score = np.where(missing, 0, y_score)
            roc_auc = roc_auc_score(y_true, y_score)
            df.append([s, ci, roc_auc, by])
df = pd.DataFrame(df, columns=['cluster', 'archetype', 'roc.auc', 'by'])
df.head()
from lib.Archetypes import Archetypes
# Archetype metadata: `clu` provides 'Cluster_ID' and 'Name' per archetype,
# `motif` provides the member motifs ('Motif') of every 'Cluster_ID'.
clu = Archetypes.get_archetypes_clusters(datadir='data')
motif = Archetypes.get_archetypes_motifs(datadir='data')
# TF symbol = motif-name prefix before the first '_' and the first '.',
# lower-cased and then capitalised.
motif['symbol'] = (
    motif['Motif']
    .str.split('_').str[0]
    .str.split('.').str[0]
    .str.lower()
    .str.capitalize()
)
# Lookup: cluster ID -> archetype name.
name_by_cluster = dict(zip(clu['Cluster_ID'], clu['Name']))
def plot_motifs_rows_clustermap(fg_rows, log=False):
    """Draw one sequence logo per row label in the right-most subplot column.

    For every label, the text before the first ``', '`` is taken as the
    archetype name and resolved to a cluster ID through the module-level
    ``clu`` table. One motif of that cluster is picked from the module-level
    ``motif`` table (a HOCOMOCO mouse or JASPAR model when available, else the
    first listed motif), its matrix is read from ``data/`` and normalised
    per position, and it is drawn with ``HumanTFs.plot_pwm_model`` on the
    current pyplot figure.

    Parameters
    ----------
    fg_rows : sequence of str
        Row labels, in plotting order (top to bottom).
    log : bool, optional
        If True, print the chosen motif ID and its normalised matrix.

    Raises
    ------
    IndexError
        If an archetype name is not found in ``clu``, a cluster has no motif
        in ``motif``, or a MEME motif cannot be delimited in the Jolma file.
    """
    set_figdpi(100)
    set_figsize(5, 10)
    nrow = len(fg_rows)
    nucleotides = ['A', 'C', 'G', 'T']
    rows_meme = None  # lines of the Jolma MEME file, read at most once
    for pi, idx in enumerate(fg_rows, start=1):
        cluster_name = idx.split(', ')[0]
        cluster_id = list(clu[clu['Name'] == cluster_name]['Cluster_ID'])[0]
        cluster_motifs = motif[motif['Cluster_ID'] == cluster_id]['Motif']
        is_hocomoco_mouse = cluster_motifs.str.contains('H11MO') & cluster_motifs.str.contains('MOUSE')
        is_jaspar = cluster_motifs.str.contains('_MA')
        preferred = list(cluster_motifs[is_hocomoco_mouse | is_jaspar])
        # Fall back to the first motif of the cluster when no preferred model exists.
        pfm_id = preferred[0] if preferred else list(cluster_motifs)[0]
        if log:
            print(pi, pfm_id)

        if '_MA' in pfm_id:  # JASPAR: header line, then one "X [ counts ]" line per nucleotide
            pfm_path = 'data/JASPAR/%s.jaspar' % pfm_id.split('_')[-1]
            with open(pfm_path) as handle:
                rows = [r.strip() for r in handle]
            counts = [
                list(map(int, r.replace("]", "").replace("[", "").split()[1:]))
                for r in rows[1:]
                if r  # tolerate blank lines at the end of the file
            ]
            ppm = pd.DataFrame(counts, index=nucleotides)
        elif 'H11MO' in pfm_id:  # HOCOMOCO: tab-separated count matrix, one row per position
            pfm_path = 'data/HOCOMOCOv11/pcm/%s.pcm' % pfm_id
            pfm = pd.read_csv(pfm_path, skiprows=1, sep='\t', header=None)
            pfm.columns = nucleotides
            ppm = pfm.transpose()
        else:  # Jolma 2013 motifs, stored together in a single MEME file
            meme_path = 'data/Taipale_2013/jolma2013.meme'
            if rows_meme is None:
                with open(meme_path) as handle:
                    rows_meme = [r.strip() for r in handle]
            start = [ri for ri, r in enumerate(rows_meme) if pfm_id in r][0]
            # Offset (relative to `start`) of the second line mentioning 'MOTIF',
            # i.e. the header of the following motif.
            end = [ri for ri, r in enumerate(rows_meme[start:]) if 'MOTIF' in r][1]
            # Drop the three leading and two trailing lines around the matrix rows.
            matrix_rows = rows_meme[start: start + end][3:-2]
            values = [[float(v) for v in p.split(' ') if v] for p in matrix_rows]
            ppm = pd.DataFrame(values, columns=nucleotides).transpose()

        # freqs to probs: every position (column) is scaled to sum to one
        ppm = ppm / ppm.sum(axis=0)
        ppm.columns = list(range(len(ppm.columns)))
        if log:
            print(ppm)

        # Right-most cell of a (nrow + 1) x 4 grid; the first grid row is left free.
        ax = plt.subplot(nrow + 1, 4, pi * 4 + 4)
        HumanTFs.plot_pwm_model(pfm_id, ppm=ppm, ax=ax)
        plt.xlabel('')
        # Hide the left and bottom spines
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.yaxis.set_label_position("right")
        ax.yaxis.tick_right()
        plt.ylabel(cluster_name, rotation=0, va='center', ha='left', labelpad=2.0)
        plt.tick_params(left=False, bottom=False, right=False, labelleft=False, labelbottom=False, labelright=False)
    plt.subplots_adjust(bottom=.5)
# TF annotation table; the columns used here are 'TF_Name', 'TF_Status' and
# 'DBID.1' (a motif/model identifier).
tfs = pd.read_csv('data/hs/TF_Information_all_motifs_plus.txt', sep='\t')
# Model identifier -> TF name, taken from rows with TF_Status == 'D'
# (presumably "directly determined" motifs — confirm against the data source).
symbol_by_model = tfs[tfs['TF_Status'] == 'D'].set_index('DBID.1')['TF_Name'].to_dict()
# Keep TFs that have no archetype symbol of their own and whose motif has
# TF_Status == 'I' (presumably "inferred"). The copy makes the column
# assignment below act on an independent frame instead of a slice.
# NOTE(review): at this point motif['symbol'] is still the capitalised form
# computed when the archetype motifs were first loaded, whereas the symbols
# rebuilt a few lines below are upper-cased — confirm which casing matches
# 'TF_Name' in this table.
tfs = tfs[~tfs['TF_Name'].isin(motif['symbol']) & (tfs['TF_Status'] == 'I')].copy()
tfs['TF_Name_other'] = tfs['DBID.1'].map(symbol_by_model)
tfs = tfs[~pd.isnull(tfs['TF_Name_other'])]
tfs[tfs['TF_Name'] == 'Neurog3']
print(len(set(tfs['TF_Name'])))
# Reload the archetype motifs and derive upper-case TF symbols.
motif = Archetypes.get_archetypes_motifs(datadir='data')
motif['symbol'] = motif['Motif'].str.split('_').str[0].str.split('.').str[0].str.upper()
# For every TF without its own archetype, borrow one archetype row from a TF
# whose motif it was mapped to, and relabel that row with the missing TF.
others = []
for missing_tf, tf_rows in tfs.groupby('TF_Name'):
    motif_other = motif[motif['symbol'].isin(tf_rows['TF_Name_other'])].drop_duplicates('Cluster_ID').copy()
    motif_other['symbol'] = missing_tf
    others.append(motif_other.drop_duplicates("symbol"))
# pd.concat rejects an empty list, so fall back to an empty frame when no TF qualified.
others = pd.concat(others) if others else motif.iloc[:0].copy()
motif = pd.concat([motif, others])
def grouped_obs_mean(adata, group_key, layer=None, gene_symbols=None):
    """Return the mean of every variable within every observation group.

    Parameters
    ----------
    adata
        AnnData-like object; ``.obs``, ``.var``, ``.var_names``, ``.shape``,
        ``.X`` / ``.layers`` and row indexing with integer positions are used.
    group_key : str
        Column of ``adata.obs`` that defines the groups.
    layer : str, optional
        Name of the layer to average; ``adata.X`` is used when None.
    gene_symbols : str, optional
        Column of ``adata.var`` whose values label the rows of the result;
        ``adata.var_names`` is used when None.

    Returns
    -------
    pandas.DataFrame
        One row per variable and one column per group, holding float64 means.
    """
    if gene_symbols is not None:
        new_idx = pd.Index(adata.var[gene_symbols])
    else:
        new_idx = adata.var_names
    grouped = adata.obs.groupby(group_key)
    out = pd.DataFrame(
        np.zeros((adata.shape[1], len(grouped)), dtype=np.float64),
        columns=list(grouped.groups.keys()),
        index=new_idx,
    )
    # `indices` maps every group to the integer row positions of its members.
    for group, idx in grouped.indices.items():
        subset = adata[idx]
        X = subset.X if layer is None else subset.layers[layer]
        # np.ravel flattens the 1 x n_vars result that matrix-like X types return.
        out[group] = np.ravel(X.mean(axis=0, dtype=np.float64))
    return out
from matplotlib import rcParams
import matplotlib.pyplot as plt
# One scatter panel per group: archetype ROC-AUC (x) against the group's mean
# expression of the matching TF relative to the mean over all groups (y).
rcParams['figure.dpi'] = 100
rcParams['figure.figsize'] = [20, 20]
res = []
pi = 0
group_keys = ['leiden', 'cell.type']
# Sorted so that the panel order is reproducible between runs.
clusters_by_key = {
    by: sorted(set(df[df['by'] == by]['cluster']), key=str) for by in group_keys
}
# At least the 5 x 5 grid; it grows by whole rows when there are more than 25 panels.
n_cols = 5
n_panels = sum(len(clusters) for clusters in clusters_by_key.values())
n_rows = max(5, -(-n_panels // n_cols))
for by in group_keys:
    expr_mean = grouped_obs_mean(adata, by)
    # Mean over all groups of this annotation; identical for every panel.
    expr_overall = expr_mean.mean(axis=1)
    for k in clusters_by_key[by]:
        plt.subplot(n_rows, n_cols, pi + 1)
        # Copy so that the new columns are written to an independent frame.
        grp = df[(df['cluster'] == k) & (df['by'] == by)].copy()
        # The trailing number of the score column name is matched to 'Cluster_ID'.
        grp['cluster.number'] = grp['archetype'].str.split('.').str[-1].astype(int)
        grp = grp.merge(motif, left_on='cluster.number', right_on='Cluster_ID')
        expr_by_gene = (expr_mean[k] - expr_overall).to_dict()
        grp['gene.expr'] = grp['symbol'].map(expr_by_gene)
        grp = grp.drop_duplicates('symbol').copy()
        plt.axhline(y=.0, ls='--')
        plt.axvline(x=.5, ls='--')
        # Rank by sqrt(expression * AUC); negative expression and AUC below
        # 0.5 contribute zero. Rank 1 is the best combination.
        expr_pos = np.where(grp['gene.expr'] < 0, 0, grp['gene.expr'])
        auc_pos = np.where(grp['roc.auc'] < 0.5, 0, grp['roc.auc'])
        grp['rank'] = pd.Series(np.sqrt(expr_pos * auc_pos).flatten(),
                                index=grp.index).rank(ascending=False)
        grp['color'] = np.where((grp['roc.auc'] > 0.5) & (grp['rank'] < 5), 'red', 'gray')
        # Symbols without an expression value cannot be plotted.
        grp = grp[~np.isnan(grp['gene.expr'])]
        res.append(grp)
        plt.scatter(grp['roc.auc'], grp['gene.expr'], s=5, c=grp['color'])
        for ri, r in grp[(grp['color'] == 'red') | (grp['gene.expr'] > .5)].iterrows():
            plt.annotate(r['symbol'], (r['roc.auc'], r['gene.expr']), fontsize=10)
        plt.xlabel('ROC-AUC')
        plt.ylabel('expression')
        plt.ylim([-.25, 1.8])
        plt.xlim([.4, .9])
        plt.title('%s (%s)' % (k, by))
        pi += 1
res = pd.concat(res)
plt.subplots_adjust(bottom=.5, hspace=.8, wspace=.6)
# TFs that were highlighted or annotated in at least one panel.
sel_tfs = set(res[(res['color'] == 'red') | (res['gene.expr'] > .5)]['symbol'])
sel_archetypes = motif[motif['symbol'].isin(sel_tfs)]
archetype_by_symbol = sel_archetypes.set_index('symbol')['Cluster_ID'].to_dict()
# Disabled cell: pooled scatter of all groups. Flip the condition to run it.
if False:
    res['color'] = np.where((res['roc.auc'] > 0.5) & (res['gene.expr'] > .25), 'red', 'gray')
    plt.scatter(res['roc.auc'], res['gene.expr'], s=5, c=res['color'])
    # Annotate from the pooled table that is being plotted, not from `grp`,
    # which only holds the last panel of the preceding loop.
    for ri, r in res[res['gene.expr'] > .5].iterrows():
        plt.annotate(r['symbol'], (r['roc.auc'], r['gene.expr']), fontsize=6)
    plt.xlabel('ROC-AUC')
    plt.ylabel('expression [mean]')
    plt.title('Expression versus reg. potential')
sc.pl.umap(adata, ncols=2, color=['leiden', 'cell.type'])
# Unique key per (annotation, group) pair.
res['k'] = res['by'] + '_' + res['cluster']
# Five best rows per key: lowest rank first, ties broken by higher ROC-AUC;
# the selection is then ordered by key and rank.
sel_df = (
    res.sort_values(['rank', 'roc.auc'], ascending=[True, False])
    .groupby('k')
    .head(5)
    .sort_values(['k', 'rank'])
)
sel_tfs = set(sel_df['symbol'])
print(len(sel_tfs))
def plot_best_cases_umap(by='cell.type'):
    """Plot, for every group of ``by``, UMAPs of its five best TF/archetype pairs.

    For each key of the module-level ``res`` table that belongs to ``by``,
    three UMAP figures are drawn: group membership, expression of the selected
    TFs, and the matching archetype scores.

    Side effects: adds a ``'cluster.name'`` column to ``res`` and one boolean
    ``'is.<group>'`` column per group to ``adata.obs``.

    Parameters
    ----------
    by : str
        Column of ``adata.obs`` (and prefix of ``res['k']``) to plot.
    """
    set_figsize(4, 4)
    set_figdpi(100)
    res['cluster.name'] = res['Cluster_ID'].map(name_by_cluster)
    # Keys are built as '<by>_<group>'. Match the prefix exactly and keep
    # everything after it, so that group names containing '_' stay intact.
    prefix = by + '_'
    for k in res['k'].dropna().unique():
        if not k.startswith(prefix):
            continue
        cell_type_name = k[len(prefix):]
        adata.obs['is.%s' % cell_type_name] = adata.obs[by] == cell_type_name
        # Five best rows (lowest rank, then highest ROC-AUC), shown by decreasing ROC-AUC.
        grp = (
            res[res['k'] == k]
            .sort_values(['rank', 'roc.auc'], ascending=[True, False])
            .head(5)
            .sort_values('roc.auc', ascending=False)
        )
        set_figdpi(60)
        sc.pl.umap(adata, color='is.%s' % cell_type_name, ncols=5, cmap='Set1_r', title=cell_type_name)
        set_figdpi(70)
        print('TF-expression')
        sc.pl.umap(adata, color=list(grp['symbol']), ncols=5, cmap='magma_r',
                   title=list(grp['cluster.name']), vmin=0)
        print('regulatory potential')
        sc.pl.umap(adata, color=list(grp['archetype']), ncols=5, cmap='magma_r',
                   title=list(k + '_' + grp['symbol']), vmin=0)
def plot_best_motifs(by='cell.type'):
    """Draw a clustered ROC-AUC heatmap of the selected archetypes with logos.

    Rows are ``'<archetype name>, (<TF symbol>)'`` labels of the archetypes
    selected in the module-level ``sel_df`` for ``by``; columns are the groups
    of ``by``; values are the ROC-AUCs from the module-level ``res``. The
    heatmap is squeezed to the left of the figure and
    ``plot_motifs_rows_clustermap`` draws one logo per row, in dendrogram
    order, to its right.

    Parameters
    ----------
    by : str
        Annotation (value of the ``'by'`` column) to plot.
    """
    # Import the submodule explicitly; `import matplotlib` alone does not
    # guarantee that `matplotlib.gridspec` is loaded.
    from matplotlib import gridspec
    import seaborn as sns
    grp = sel_df[sel_df['by'] == by]
    best_with_expr = set(grp['archetype'])
    # Copy so that the helper columns are written to an independent frame.
    sel = res[res['archetype'].isin(best_with_expr) & (res['by'] == by)].copy()
    # One row per (group, archetype) pair.
    sel['k2'] = sel['cluster'] + sel['archetype']
    sel = sel.drop_duplicates('k2').copy()
    sel['archetype.name'] = sel['Cluster_ID'].map(name_by_cluster)
    sel['archetype.TF'] = sel['archetype.name'] + ', (' + sel['symbol'] + ')'
    # Keyword arguments: pandas >= 2.0 no longer accepts them positionally.
    roc_hm = sel.pivot(index='cluster', columns='archetype.TF', values='roc.auc').transpose()
    fg = sns.clustermap(roc_hm, vmin=.5, vmax=.9, cmap='Blues')
    plt.subplots_adjust(bottom=.5, left=.4)
    # Shrink the clustermap to the left part of the figure.
    fg.gs.update(left=0.05, right=0.3)
    plt.setp(fg.ax_heatmap.get_xticklabels(), rotation=90, fontsize=7)
    # create a new gridspec for the right part and add an axes within it
    gs2 = gridspec.GridSpec(1, 1, left=0.6, top=.6)
    fg.fig.add_subplot(gs2[0])
    plot_motifs_rows_clustermap(roc_hm.index[fg.dendrogram_row.reordered_ind])
# Render the motif heatmap followed by the per-group UMAP panels, first for
# the 'cell.type' annotation and then for the 'leiden' clusters.
for group_key in ('cell.type', 'leiden'):
    plot_best_motifs(by=group_key)
    plot_best_cases_umap(by=group_key)